Skip to content

Add optional Approximate Top-K configuration for MLA Indexer#4243

Merged
copybara-service[bot] merged 1 commit into
mainfrom
zjiahao/DSA3.2-approx-top-k
Jun 26, 2026
Merged

Add optional Approximate Top-K configuration for MLA Indexer#4243
copybara-service[bot] merged 1 commit into
mainfrom
zjiahao/DSA3.2-approx-top-k

Conversation

@JHCuc3m

@JHCuc3m JHCuc3m commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR adds an optional configuration parameter indexer_use_approx_top_k to the Multi-head Latent Attention (MLA) Indexer, allowing users to enable JAX's TPU-optimized approx_max_k primitive instead of the default exact top_k selection.

Why is this change being made?

During performance investigations of DeepSeek-V3.2 in long-context mode (128K sequence length) with 128-way Context Parallelism (CP=128), the top-k was identified a major bottleneck in the MLA Indexer forward pass:

  • The indexer computes a local score matrix of shape [1, 1024, 131072] on each device.
  • The default exact top_k selection introduces a massive step-time overhead for sorting 131K elements per layer across 58 layers.

Benchmarks show a 4x speedup on f32[1,1024,65536] tensors when tested with the approximate path enabled using a recall target of 0.95.

Why this is a good solution

JAX's approx_max_k employs block-based reduction optimized for TPU Matrix Units (MXU). It reduces complexity from $\mathcal{O}(N \log^2 N)$ to $\approx \mathcal{O}(N + K \log K)$ with a significantly smaller constant factor. Paper: https://arxiv.org/pdf/2206.14286.

Workload show a ~4x speedup on f32[1, 1024, 65536] tensors when tested on TPU with the approximate path enabled using a recall target of 0.95.

Specific Implementation Details

  1. Configuration Schema (types.py): Added indexer_use_approx_top_k and indexer_approx_top_k_recall to the AttentionIndexer Pydantic class to pass configuration validation.
  2. Default Config (base.yml): Exposed the parameters with safe defaults (indexer_use_approx_top_k: false, indexer_approx_top_k_recall: 0.95).
  3. Model Architecture (attention_mla.py): Updated Indexer.__call__ to conditionally route the selection to jax.lax.approx_max_k when enabled.

Shortcomings & Future Improvements

  • There is not systematic study on how the accuracy lost of using indexer_use_approx_top_k instead of top_k might affect downstream model performance, while it is expected to be minimal when a high recall rate is used.

Tests

1. Regression Guard (Default Path)

We ran the attention unit test suite with the default configuration (indexer_use_approx_top_k=false) to ensure no regressions:

  • Command: pytest tests/unit/attention_test.py
  • Result: PASSED (20 passed, 32 skipped).

2. Compilation & Tracing Safety

We added a new unit test, test_indexer_with_approx_top_k, to verify that the new path compiles and traces successfully in JAX:

  • Command: pytest tests/unit/attention_test.py -k test_indexer_with_approx_top_k
  • Result: PASSED.

3. Mathematical Correctness & Recall Tracking

We added a correctness test, test_approx_top_k_recall, which generates random scores of shape [4, 16, 1024], runs both exact and approximate top-K ($K=64$), and calculates the actual recall:

  • Command: pytest tests/unit/attention_test.py -k test_approx_top_k_recall -s
  • Result: PASSED (Achieved 100% recall on CPU).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

@codecov

codecov Bot commented Jun 23, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@JHCuc3m JHCuc3m force-pushed the zjiahao/DSA3.2-approx-top-k branch 3 times, most recently from 9ef80f2 to 53d7917 Compare June 23, 2026 22:23
@JHCuc3m

JHCuc3m commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Code quality checker failed from existing file content

Comment thread tests/unit/attention_test.py Outdated

# Assert that the actual recall is close to or exceeds the target.
# We allow a small margin (e.g., 0.05) due to the approximate nature and sample size.
self.assertGreaterEqual(mean_recall, recall_target - 0.05)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is 0.05 too large? What about making it 0.01?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

@NuojCheng NuojCheng left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a very cool optimization! Thank you Jiahao

@JHCuc3m JHCuc3m force-pushed the zjiahao/DSA3.2-approx-top-k branch from e4cc351 to 76dae1b Compare June 26, 2026 21:47
@copybara-service copybara-service Bot merged commit 2d1d53e into main Jun 26, 2026
53 of 56 checks passed
@copybara-service copybara-service Bot deleted the zjiahao/DSA3.2-approx-top-k branch June 26, 2026 22:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants